import json
from pathlib import Path
from typing import NamedTuple, Union

import numpy as np
import pandas as pd
import preprocess
from preprocess import utils
import qmt
import tree


class Config(NamedTuple):
    """Locations of the raw recordings plus per-experiment processing settings.

    Only ``exp_folder`` must be given; every other field has a default.
    NOTE: the dict/list defaults are single objects shared by all instances
    (NamedTuple does not copy defaults) -- do not mutate them in place.
    """

    # Root folder containing one sub-folder per experiment id.
    exp_folder: str

    # Motion phase used to estimate the OMC<->IMU alignment, per experiment id.
    # A value is either a single motion name (the alignment window runs from
    # that motion's start to the start of the next motion in the timings file)
    # or an explicit ``(start_motion, stop_motion)`` pair.
    # Experiment ids must be present here to be aligned.
    # The chosen phase should not contain overly aggressive motion; otherwise
    # the acceleration estimate "hides" gravity.
    # Note: S_04 is configured here but is not listed in `valid_exp_ids`.
    alignment_timings_motion: dict[str, Union[str, tuple[str, str]]] = {
        "S_04": "slow",
        "S_06": "slow1",
        "S_07": "slow_fast_mix",
        "S_08": "slow1",
        "S_09": "slow_global",
        "S_10": "pickandplace",
        "S_12": "slow1",
        "S_13": "slow_fast_mix",
        "S_14": "slow",
        "S_15": "slow_global",
        "S_16": "gait_slow",
        "T_01": ("slow", "shaking"),
    }

    # Experiments that are processed; their position (1-based) becomes the
    # exported experiment number.
    valid_exp_ids: list[str] = [
        "S_06",
        "S_07",
        "S_08",
        "S_09",
        "S_10",
        "S_12",
        "S_13",
        "S_14",
        "S_15",
        "S_16",
        "T_01",
    ]

    # Sensor names and segment names as used in the setup file.
    imus: list[str] = ["imu_rigid", "imu_flex"]
    segments: list[str] = ["seg1", "seg2", "seg3", "seg4", "seg5"]

    # Class attribute, not a tuple field (it has no annotation): path of the
    # marker/IMU setup JSON that lives next to this file.
    setup_file = Path(__file__).parent.joinpath("preprocess/setup.json")

    def path_optitrack_file(self, exp_id: str) -> str:
        """Return the path of the OptiTrack CSV export for ``exp_id``."""
        return f"{self.exp_folder}/{exp_id}/optitrack/{exp_id}.csv"

    def path_imu_folder(self, exp_id: str) -> str:
        """Return the folder holding the IMU recordings for ``exp_id``."""
        return f"{self.exp_folder}/{exp_id}/imu"


def _arm_or_gait(exp_id: str) -> str:
    if exp_id in ["S_04", "S_06", "S_07", "S_08", "S_09", "S_10"]:
        return "arm"
    elif exp_id in ["S_12", "S_13", "S_14", "S_15", "S_16", "T_01"]:
        return "gait"
    else:
        raise Exception(f"`exp_id` is not valid {exp_id}")


def _alignment_timings(exp_id: str, config: Config) -> tuple[float, float]:
    """Return the ``(start, stop)`` times of the alignment window for ``exp_id``.

    ``config.alignment_timings_motion[exp_id]`` is either a motion name, in
    which case the window stops where the following motion in the timings
    file starts, or an explicit ``(start_motion, stop_motion)`` pair.

    Raises:
        KeyError: if ``exp_id`` has no configured alignment motion.
        IndexError: if a single motion name is the last motion in the timings.
    """
    timings = preprocess.load_timings(exp_id)
    ordered_motions = list(timings)
    selection = config.alignment_timings_motion[exp_id]
    if not isinstance(selection, str):
        start_motion, stop_motion = selection
    else:
        start_motion = selection
        stop_motion = ordered_motions[ordered_motions.index(selection) + 1]
    return timings[start_motion], timings[stop_motion]


def _marker_closest_to_rigid_imu(seg: str, config: Config) -> int:
    """Return the number of the marker on ``seg`` that sits closest to its rigid IMU.

    The value is read from the ``marker_closest_to_rigid_imu`` entry of the
    segment in ``config.setup_file``; the file is re-read on every call.
    """
    setup = json.loads(Path(config.setup_file).read_text())
    return int(setup[seg]["marker_closest_to_rigid_imu"])


def _get_alignment(data: dict, exp_id: str, hz_omc, hz_imu, config: Config):
    """Estimate the OMC<->IMU alignment for one experiment.

    The synced data is resampled to 120 Hz, its tail is cropped, and it is
    cut to the alignment window of ``exp_id``. For every segment, the rigid
    IMU, the segment quaternion and the marker closest to the rigid IMU are
    then passed to ``qmt.alignOptImu``. The full result is written to
    ``alignment_infos/alignment_info_<exp_id>.json`` next to this file, and
    the per-segment result is printed.

    Returns:
        ``(qEOpt2EImu_euler_deg, qImu2Seg_euler_deg)``. The second item maps
        each segment name to ``{"imu_rigid": <euler angles>}``.
    """
    rate = 120.0
    rates_in = utils.hz_helper(
        list(data.keys()), hz_imu=hz_imu, hz_omc=hz_omc, imus=config.imus
    )
    resampled = utils.resample(data, rates_in, rate, vecinterp_method="cubic")
    resampled = utils.crop_tail(resampled, rate)
    start, stop = _alignment_timings(exp_id, config)
    window = utils.crop_sequence(resampled, 1 / rate, start, stop)

    segments = list(window.items())
    seg_names = [name for name, _ in segments]
    rigid_imus = [seg["imu_rigid"] for _, seg in segments]
    info = qmt.alignOptImu(
        [imu["gyr"] for imu in rigid_imus],
        [imu["acc"] for imu in rigid_imus],
        [imu["mag"] for imu in rigid_imus],
        [seg["quat"] for _, seg in segments],
        [
            seg[f"marker{_marker_closest_to_rigid_imu(name, config)}"]
            for name, seg in segments
        ],
        rate=rate,
        names=seg_names,
        params=dict(fast=True),
    )

    def _to_builtin(leaf):
        # numpy arrays are not JSON serialisable; everything else is kept as-is.
        return leaf.tolist() if isinstance(leaf, np.ndarray) else leaf

    out_file = Path(__file__).parent / "alignment_infos" / f"alignment_info_{exp_id}.json"
    out_file.parent.mkdir(parents=True, exist_ok=True)
    with open(out_file, "w") as fh:
        json.dump(tree.map_structure(_to_builtin, info), fh, indent=1)

    earth_alignment = info["qEOpt2EImu_euler_deg"]
    imu_to_segment = {
        name: {"imu_rigid": info[f"qImu2Seg_{name}_euler_deg"]} for name in window
    }
    print(imu_to_segment)
    return earth_alignment, imu_to_segment


def _to_data_split(exp_id: str, config: Config) -> dict[str, dict]:
    """Load, sync, optionally align, and split one experiment into its motions.

    The OMC and IMU streams are read and synchronised with
    ``preprocess.read_omc``. If ``exp_id`` has an entry in
    ``config.alignment_timings_motion``, an OMC<->IMU alignment is estimated
    and the data is re-read with that alignment, reusing the sync offset from
    the first read. Otherwise the synced but unaligned data is used.

    Returns:
        Mapping from motion name (in timings-file order) to the data cropped
        to that motion's window. The last motion is cropped with no stop time
        (``t2=None``).

    Raises:
        ValueError: if consecutive motion start times are not strictly increasing.
    """
    path_imu_folder = config.path_imu_folder(exp_id)
    path_optitrack_file = config.path_optitrack_file(exp_id)

    hz_omc = float(utils.autodetermine_optitrack_freq(path_optitrack_file))
    hz_imu = float(utils.autodetermine_imu_freq(path_imu_folder))

    # Arguments shared by both `read_omc` calls below.
    read_kwargs = dict(
        path_marker_imu_setup_file=config.setup_file,
        path_optitrack_file=path_optitrack_file,
        path_imu_folder=path_imu_folder,
        imu_names_setup_file=config.imus,
        segment_names_setup_file=config.segments,
    )

    # First read: synchronise the IMU and OMC streams.
    data, imu_sync_offset = preprocess.read_omc(**read_kwargs)

    # Alignment needs a configured motion window (looked up by
    # `_alignment_timings`); experiments without one keep the unaligned data.
    if exp_id in config.alignment_timings_motion:
        qEOpt2EImu_euler_deg, qImu2Seg_euler_deg = _get_alignment(
            data, exp_id, hz_omc, hz_imu, config
        )
        # Second read: same sync offset, now with the estimated alignment.
        data, _ = preprocess.read_omc(
            **read_kwargs,
            imu_sync_offset=imu_sync_offset,
            qEOpt2EImu_euler_deg=qEOpt2EImu_euler_deg,
            qImu2Seg_euler_deg=qImu2Seg_euler_deg,
        )

    hz_in = utils.hz_helper(
        list(data.keys()),
        hz_imu=hz_imu,
        hz_omc=hz_omc,
        imus=config.imus,
    )
    data = utils.crop_tail(data, hz_in, strict=False)

    # Split into individual motions: each one runs from its own start time
    # to the start time of the next motion.
    data_split = {}
    timings = preprocess.load_timings(exp_id)
    timings_list = list(timings.keys())
    for motion_start, motion_stop in zip(timings_list, timings_list[1:] + [None]):
        t1 = timings[motion_start]
        t2 = timings[motion_stop] if motion_stop is not None else None
        if t2 is not None and not t2 > t1:
            raise ValueError(f"t2={t2};t1={t1};timings={timings}")
        data_split[motion_start] = tree.map_structure(
            lambda d, hz, t1=t1, t2=t2: utils.crop_sequence(d, 1 / hz, t1, t2),
            data,
            hz_in,
        )

    return data_split


def _add_header_to_csv_file(path: str, freq: float):

    header = f"# sampling frequency: {freq}\n"
    header += "# units are seconds/meters/radians/a.u.\n"

    with open(path, "r") as original:
        data = original.read()
    with open(path, "w") as modified:
        modified.write(header + data)


def _dump_omc_csv(path, data: dict, freq):
    """Write the marker positions and quaternions of seg1..seg5 to a CSV file.

    For each segment the columns are ``<seg>_marker<1..4>_<x|y|z>`` followed
    by ``<seg>_quat_<w|x|y|z>``. Each array is indexed as ``[:, axis]``.
    A frequency/units header is prepended after writing.
    """
    columns = {}
    for seg in (f"seg{n}" for n in range(1, 6)):
        seg_data = data[seg]
        for marker in range(1, 5):
            positions = seg_data[f"marker{marker}"]
            for axis, label in enumerate("xyz"):
                columns[f"{seg}_marker{marker}_{label}"] = positions[:, axis]
        quat = seg_data["quat"]
        for axis, label in enumerate("wxyz"):
            columns[f"{seg}_quat_{label}"] = quat[:, axis]

    table = pd.DataFrame.from_dict(columns)
    table.to_csv(path, sep=",", index=False)
    _add_header_to_csv_file(path, freq)


def _dump_imu_csv(path, data: dict, freq, imu: str):
    """Write the acc/gyr/mag channels of sensor ``imu`` for seg1..seg5 to a CSV file.

    Columns are ``<seg>_<acc|gyr|mag>_<x|y|z>``, ordered by segment, then
    channel, then axis. Each channel array is indexed as ``[:, axis]``.
    A frequency/units header is prepended after writing.
    """
    columns = {}
    for seg in (f"seg{n}" for n in range(1, 6)):
        sensor = data[seg][imu]
        for channel in ("acc", "gyr", "mag"):
            readings = sensor[channel]
            for axis, label in enumerate("xyz"):
                columns[f"{seg}_{channel}_{label}"] = readings[:, axis]

    table = pd.DataFrame.from_dict(columns)
    table.to_csv(path, sep=",", index=False)
    _add_header_to_csv_file(path, freq)


def process_exp_id(exp_id: str, new_exp_id: int, config: Config):
    """Convert one raw experiment into the per-motion CSV dataset layout.

    Output goes to ``dataset/<arm|gait>/exp<NN>/motion<MM>_<motion>/`` next
    to this file. ``NN`` is ``new_exp_id`` and ``MM`` is the 1-based motion
    index, both left-padded with zeros to two characters. Each motion folder
    receives ``exp<NN>_motion<MM>_omc.csv``, ``..._imu_rigid.csv`` and
    ``..._imu_nonrigid.csv`` (the latter written from the ``imu_flex`` sensor).

    Raises:
        ValueError: if ``exp_id`` is not in ``config.valid_exp_ids``.
    """
    if exp_id not in config.valid_exp_ids:
        raise ValueError(f"`exp_id` is not valid {exp_id}")

    data_split = _to_data_split(exp_id, config)

    exp_tag = str(new_exp_id).rjust(2, "0")

    root = Path(__file__).parent.joinpath(
        f"dataset/{_arm_or_gait(exp_id)}/exp{exp_tag}"
    )
    root.mkdir(exist_ok=True, parents=True)

    # Sampling rates depend only on the experiment's files, not on the motion:
    # detect them once instead of re-reading the files for every motion.
    hz_omc = utils.autodetermine_optitrack_freq(config.path_optitrack_file(exp_id))
    hz_imu = utils.autodetermine_imu_freq(config.path_imu_folder(exp_id))

    for motion_number, motion in enumerate(data_split, start=1):
        motion_tag = str(motion_number).rjust(2, "0")
        folder = root.joinpath(f"motion{motion_tag}_{motion}")
        folder.mkdir(exist_ok=True)
        file_prefix = str(folder.joinpath(f"exp{exp_tag}_motion{motion_tag}_"))

        data_motion_renamed = preprocess.rename_segments(
            data_split[motion], config.setup_file
        )

        _dump_omc_csv(file_prefix + "omc.csv", data_motion_renamed, freq=hz_omc)
        _dump_imu_csv(
            file_prefix + "imu_rigid.csv",
            data_motion_renamed,
            hz_imu,
            "imu_rigid",
        )
        _dump_imu_csv(
            file_prefix + "imu_nonrigid.csv",
            data_motion_renamed,
            hz_imu,
            "imu_flex",
        )


def main():
    """Convert every experiment listed in ``Config.valid_exp_ids``.

    Raw data is read from ``raw_data/`` next to this file. Experiments are
    exported as numbers 1, 2, ... in the order they appear in
    ``valid_exp_ids``.
    """
    data_folder = Path(__file__).parent.joinpath("raw_data")
    config = Config(exp_folder=str(data_folder))

    # enumerate gives the 1-based export number directly; the previous
    # `list.index` lookup rescanned the list for every experiment.
    for new_exp_id, exp_id in enumerate(config.valid_exp_ids, start=1):
        process_exp_id(exp_id, new_exp_id, config)


if __name__ == "__main__":
    main()
